package shell

import (
	"errors"
	"fmt"
	"github.com/fatih/color"
	"github.com/pkg/sftp"
	"gosh/pkg/connect"
	"gosh/pkg/handler"
	"io"
	"io/fs"
	"os"
	"path"
	"strings"
	"time"
)

// PullFiles is the entry point of the pull command. It initializes the
// bastion (jump server) connections first; if that fails the error is printed
// in red and nothing is pulled. Otherwise the hosts in SSHConfig.Hosts are
// handed to handler.Handler together with the per-host worker pull, using the
// Fork / PrintStatus / PrintFields settings from SSHConfig.
func PullFiles() {
	err := initJumpServer()
	if err != nil {
		color.Red("init bastion failed,err:%s\n", err.Error())
		return
	}

	handler.Handler(
		SSHConfig.Fork,
		SSHConfig.Hosts,
		SSHConfig.PrintStatus,
		SSHConfig.PrintFields,
		pull,
	)
}

// pull is the per-host worker used by PullFiles. It looks up the host
// settings for ip and the jump server registered for that host's BastionIP,
// opens an SFTP session, and downloads the paths in SSHConfig.Args via
// pullFiles. The returned stdout/stderr bytes are pullFiles' transfer summary
// and per-file error messages; err is non-nil only if the session could not be
// opened (or pullFiles itself fails).
func pull(ip string) (stdout, stderr []byte, err error) {
	h := hostMap[ip]
	bastion := jumpserverMap[h.BastionIP]

	remote := connect.NewHost(
		ip,
		h.SSHUser,
		h.SSHPassword,
		h.SSHKeyContent,
		h.SSHKeyPassphrase,
		h.SSHPort,
		bastion,
		h.SSHTimeout,
	)
	err = remote.OpenSftp()
	if err != nil {
		// NOTE(review): Close is only deferred after OpenSftp succeeds. If
		// OpenSftp can fail after the SSH connection is already up, that
		// connection is not closed here — confirm against connect.Host.
		return nil, nil, err
	}
	defer remote.Close()

	return pullFiles(ip, remote.SftpClient, SSHConfig.Args)
}

// remoteFileInfo describes one non-directory entry found on the remote host by
// walk, together with the local path it is downloaded to by pullFile.
type remoteFileInfo struct {
	// srcPath is the entry's path on the remote host, as reported by the walker.
	srcPath  string
	// destPath is the local path the file is written to.
	destPath string
	// mode is the remote entry's mode; pullFile passes it as the create
	// permission of the local file.
	mode     fs.FileMode
	// size is the remote entry's size in bytes; pullFiles sums it for the
	// transfer summary of successfully pulled files.
	size     int64
}

// walk lists every non-directory entry under the remote path src and pairs it
// with the local path it should be downloaded to under dest.
//
// The last element of src is preserved in the local layout:
//   - src is a file ("/var/log/app.log")  -> dest/app.log
//   - src is a directory ("/var/log")     -> dest/log/<path relative to src>
//
// Problems are collected rather than aborting the walk: each walker error, and
// each entry that does not lie below src, becomes one error whose message ends
// in a newline so callers can append it straight to stderr output.
func walk(client *sftp.Client, src, dest string) ([]*remoteFileInfo, []error) {
	files := make([]*remoteFileInfo, 0, 1)
	errs := make([]error, 0, 1)

	srcBase := path.Base(src)
	// The sftp walker builds child paths with path.Join, which cleans them, so
	// they must be compared against the cleaned root, not the raw argument
	// ("./logs" or "logs//app" would otherwise never match as a prefix).
	root := path.Clean(src)

	walker := client.Walk(src)
	for walker.Step() {
		if err := walker.Err(); err != nil {
			errs = append(errs, errors.New(fmt.Sprintf("walk failed,err:%s\n", err.Error())))
			continue
		}
		info := walker.Stat()
		if info.IsDir() {
			continue
		}

		remotePath := walker.Path()
		// The root itself is reported with the path exactly as given; that
		// only reaches here when src is a file.
		localRel := srcBase
		if remotePath != src {
			rel, ok := relativeRemotePath(root, remotePath)
			if !ok {
				// Never map an entry outside src into (or above) dest.
				errs = append(errs, errors.New(fmt.Sprintf("walk failed,err:%s is outside %s\n", remotePath, src)))
				continue
			}
			localRel = path.Join(srcBase, rel)
		}

		files = append(files, &remoteFileInfo{
			srcPath:  remotePath,
			destPath: path.Join(dest, localRel),
			mode:     info.Mode(),
			size:     info.Size(),
		})
	}
	return files, errs
}

// relativeRemotePath returns the slash-separated path p relative to the cleaned
// directory root, and false when p does not lie strictly below root.
func relativeRemotePath(root, p string) (string, bool) {
	var prefix string
	switch root {
	case ".":
		// Children of "." are joined without a "./" prefix.
		if p == "." || p == ".." || strings.HasPrefix(p, "../") || path.IsAbs(p) {
			return "", false
		}
		return p, true
	case "/":
		prefix = "/"
	default:
		prefix = root + "/"
	}
	rel := strings.TrimPrefix(p, prefix)
	if rel == p || rel == "" {
		return "", false
	}
	return rel, true
}

// pullFiles downloads the remote paths args[:len(args)-1] from host ip into the
// local directory <args[len(args)-1]>/<ip>, using walk to expand directories and
// pullFile to copy each file.
//
// The first result holds one resFormat summary line per remote path from which
// a non-zero number of bytes was pulled. The second holds walk errors and
// per-file failure messages. The error result is non-nil only when args does
// not contain at least one remote path followed by a local directory;
// individual transfer failures are reported in the stderr bytes instead.
func pullFiles(ip string, client *sftp.Client, args []string) ([]byte, []byte, error) {
	// Previously an empty args slice panicked on args[len(args)-1], and a lone
	// argument silently pulled nothing.
	if len(args) < 2 {
		return nil, nil, errors.New("pull requires at least one remote path followed by a local directory")
	}

	pullStdout := make([]byte, 0, 1024)
	pullStderr := make([]byte, 0, 1024)

	// path.Join cleans its result, so a trailing "/" needs no trimming. Trimming
	// it first would turn a local root of "/" into "" and put the files in a
	// relative "<ip>" directory instead of "/<ip>".
	localDir := path.Join(args[len(args)-1], ip)

	for _, remoteFile := range args[:len(args)-1] {
		start := time.Now()
		var size int64

		fileInfos, errs := walk(client, remoteFile, localDir)
		for _, err := range errs {
			pullStderr = append(pullStderr, err.Error()...)
		}

		for _, fileInfo := range fileInfos {
			if err := pullFile(fileInfo, client); err != nil {
				pullStderr = append(pullStderr, fmt.Sprintf("pull %s failed, error:%s\n", fileInfo.srcPath, err.Error())...)
				continue
			}
			size += fileInfo.size
		}

		if size > 0 {
			// time.Since uses the monotonic clock, so wall-clock adjustments
			// during the transfer cannot produce a bogus or negative duration.
			pullStdout = append(pullStdout, resFormat(size, remoteFile, time.Since(start).Nanoseconds())...)
		}
	}
	return pullStdout, pullStderr, nil
}

// pullFile copies one remote file to fileInfo.destPath on the local machine.
// Missing parent directories are created with mode 0755 first.
//
// The local file is created with fileInfo.mode if it does not exist, or
// truncated if it does. As with os.OpenFile, the mode only applies when the
// file is created and is subject to the process umask.
//
// Both file handles are always closed. A failure to close the local file is
// returned, because it can indicate data that never reached disk.
func pullFile(fileInfo *remoteFileInfo, client *sftp.Client) (err error) {
	remoteFd, err := client.Open(fileInfo.srcPath)
	if err != nil {
		return err
	}
	// The remote handle is only read from, so its close error cannot mean lost
	// data; it is ignored deliberately. (It was previously never closed.)
	defer func() { _ = remoteFd.Close() }()

	// MkdirAll is a no-op when the directory already exists.
	if err = os.MkdirAll(path.Dir(fileInfo.destPath), os.FileMode(0755)); err != nil {
		return err
	}

	localFd, err := os.OpenFile(fileInfo.destPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fileInfo.mode)
	if err != nil {
		return err
	}
	defer func() {
		// Surface close errors (e.g. a delayed write failure) unless the copy
		// itself already failed.
		if cerr := localFd.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	_, err = io.Copy(localFd, remoteFd)
	return err
}
